Global View的概念和实现|OneFlow学习笔记
import oneflow as flow
x = flow.placement(type="cuda", ranks=[[0, 1, 2, 3], [4, 5, 6, 7]])
<class 'oneflow._oneflow_internal.placement'>
oneflow.placement(type="cuda", ranks=[[0, 1, 2, 3], [4, 5, 6, 7]])
type:表示设备类型,目前只支持CPU和CUDA
ranks:一个Python list,用于表示device的排布信息,ranks可以是一维至多维的,其shape表示了设备的排布信息(hierarchy)。注意ranks中的编号是全局rank(进程编号),而非每台机器本地的设备编号。上述ranks表示Tensor存放在集群中的2个节点上:节点1上是rank 0~3对应的4个设备,节点2上是rank 4~7对应的4个设备。
placement = oneflow._oneflow_internal.placement
ONEFLOW_API_PYBIND11_MODULE("", m) {
py::class_<Symbol<ParallelDesc>, std::shared_ptr<Symbol<ParallelDesc>>>(m, "placement",
py::dynamic_attr())
.def(...)
.def(py::init([](const std::string& type, const py::object& ranks) {
return
PlacementSymbolExportUtil::CreateParallelDescSymbol(type, ranks).GetOrThrow();
}),
py::arg("type"), py::arg("ranks"))
// create Symbol<ParallelDesc> object through given device_type and ranks parameters.
// `ranks` may be any Python object coercible to a numpy array; the array's shape
// becomes the placement hierarchy and its entries become the machine/device ids.
static Maybe<Symbol<ParallelDesc>> CreateParallelDescSymbol(const std::string& type,
                                                            const py::object& ranks) {
  // PyArray_FromAny returns a NEW reference (a forced copy here, via
  // NPY_ARRAY_ENSURECOPY), or NULL on failure.
  // NOTE(review): `obj` does not appear to be released on any path in this
  // excerpt — possible refcount leak; confirm against the full source file.
  auto* obj = reinterpret_cast<PyArrayObject*>(PyArray_FromAny(
      ranks.ptr(), nullptr, 0, 0, NPY_ARRAY_DEFAULT | NPY_ARRAY_ENSURECOPY, nullptr));
  if (!obj) { return Error::RuntimeError() << "placement ranks must be int64 array."; }
  // Shape of the ranks array = device hierarchy; flattened entries are parsed
  // into formatted "machine_id:device_id" strings.
  const auto& shape = JUST(GetRanksShape(obj));
  const auto& formated_machine_device_ids = JUST(ParseAndFormatRanks(obj));
  return SymbolOf(*JUST(CreateParallelDesc(type, *formated_machine_device_ids, shape)));
}
// Builds a ParallelDesc from a device tag ("cpu"/"cuda"), formatted
// "machine_id:device_id" strings, and an optional hierarchy shape, then interns
// it via the InstructionsBuilder so equal configurations share one symbol.
// Returns the interned ParallelDesc, or an error if the device tag is invalid.
static Maybe<ParallelDesc> CreateParallelDesc(
    const std::string& type, const std::vector<std::string>& formated_machine_device_ids,
    const std::shared_ptr<Shape>& hierarchy_shape) {
  JUST(CheckDeviceTag(type));
  auto parallel_conf = JUST(MakeParallelConf(type, formated_machine_device_ids, hierarchy_shape));
  std::shared_ptr<ParallelDesc> parallel_desc;
  // Capture by reference ("&parallel_desc", "&parallel_conf" — the original text
  // had these mangled into "¶llel_…" by an HTML-entity encoding bug) so the
  // instruction callback can fill in the result before PhysicalRun returns.
  JUST(PhysicalRun([&parallel_desc, &parallel_conf](InstructionsBuilder* builder) -> Maybe<void> {
    parallel_desc = JUST(builder->GetParallelDescSymbol(parallel_conf));
    return Maybe<void>::Ok();
  }));
  return parallel_desc;
}
// Translates formatted "machine:device" strings plus an optional hierarchy
// shape into a cfg::ParallelConf. Each entry must look like
// "<machine_id>:<device_id>" or "<machine_id>:<min>-<max>" (inclusive range);
// machine_id may carry a leading '@' marking an explicit rank.
// Returns an error if any entry fails validation.
Maybe<cfg::ParallelConf> MakeParallelConf(const std::string& device_tag,
                                          const std::vector<std::string>& machine_device_ids,
                                          const std::shared_ptr<Shape>& hierarchy) {
  std::shared_ptr<cfg::ParallelConf> parallel_conf = std::make_shared<cfg::ParallelConf>();
  parallel_conf->set_device_tag(device_tag);
  for (const std::string& machine_device_id : machine_device_ids) {
    // Every entry must contain a ':' separating machine id from device spec.
    size_t pos = machine_device_id.find(':');
    CHECK_NE_OR_RETURN(pos, std::string::npos) << "device_name: " << machine_device_id;
    std::string machine_id = machine_device_id.substr(0, pos);
    // machine_id is either a plain integer or "@<integer>" (explicit-rank form).
    CHECK_OR_RETURN(
        (IsStrInt(machine_id) || (machine_id[0] == '@' && IsStrInt(machine_id.substr(1)))))
        << " machine_id: " << machine_id;
    // Device spec: a single id "3" or an inclusive range "0-3".
    std::string device_id = machine_device_id.substr(pos + 1);
    size_t minus_pos = device_id.rfind('-');
    if (minus_pos == std::string::npos) {
      CHECK_OR_RETURN(IsStrInt(device_id));
    } else {
      std::string min_id = device_id.substr(0, minus_pos);
      CHECK_OR_RETURN(IsStrInt(min_id));
      std::string max_id = device_id.substr(minus_pos + 1);
      CHECK_OR_RETURN(IsStrInt(max_id));
    }
    parallel_conf->add_device_name(machine_device_id);
  }
  // The hierarchy belongs to the whole conf, not to a single device name, so
  // serialize it ONCE after the loop instead of re-copying it every iteration
  // (the original also silently dropped it when machine_device_ids was empty).
  if (hierarchy) {
    ShapeProto proto;
    hierarchy->ToProto(&proto);
    parallel_conf->mutable_hierarchy()->CopyFrom(cfg::ShapeProto(proto));
  }
  return parallel_conf;
}
// Finds (or creates) the interned symbol id for `parallel_conf`, then returns
// the ParallelDesc registered under that id in the global symbol storage.
// Equal confs therefore resolve to the same shared ParallelDesc symbol.
Maybe<ParallelDesc> InstructionsBuilder::GetParallelDescSymbol(
    const std::shared_ptr<cfg::ParallelConf>& parallel_conf) {
  int64_t symbol_id = JUST(FindOrCreateSymbolId(*parallel_conf));
  return Global<symbol::Storage<ParallelDesc>>::Get()->MaybeGetPtr(symbol_id);
}
oneflow/core/job/placement.proto
build/oneflow/core/job/placement.pb.h
build/oneflow/core/job/placement.pb.cc
build/of_cfg_proto_python/oneflow/core/job/placement_pb2.py
build/oneflow/core/job/placement.cfg.h
build/oneflow/core/job/placement.cfg.cpp
build/oneflow/core/job/placement.cfg.pybind.cpp
// Excerpt of ParallelDesc's data members (methods elided with "...").
class ParallelDesc final {
  ...
  ...
  Optional<int64_t> symbol_id_;       // id this desc is interned under, if interned
  DeviceType device_type_;            // derived from the conf's device tag (see MaybeInit)
  ParallelConf parallel_conf_;        // the protobuf conf this desc was built from
  std::shared_ptr<Shape> hierarchy_;  // device placement hierarchy (ranks shape)
  std::vector<int64_t> sorted_machine_ids_;
  // machine id -> sorted physical device ids on that machine (filled by MaybeInit)
  std::shared_ptr<HashMap<int64_t, std::shared_ptr<std::vector<int64_t>>>>
      machine_id2sorted_dev_phy_ids_;
  int64_t parallel_num_;                // presumably total device count — confirm in full source
  int64_t device_num_of_each_machine_;
  // parallel id <-> (machine id, device id) lookup tables
  std::vector<int64_t> parallel_id2machine_id_;
  std::vector<int64_t> parallel_id2device_id_;
  HashMap<int64_t, HashMap<int64_t, int64_t>> machine_id2device_id2parallel_id_;
  // TODO(lixinqi): merge cfg_parallel_conf_ and parallel_conf_ after cfg::ParallelConf taken as the
  // constructor argument
  std::shared_ptr<cfg::ParallelConf> cfg_parallel_conf_;
  // cached result of ContainingMachineId(GlobalProcessCtx::Rank()) for performance optimization.
  bool containing_current_rank_;
};
// Initializes this ParallelDesc from a user-supplied ParallelConf: resolves the
// device type from the device tag, parses every "machine:device" name into the
// rank -> device-id table, caches whether the current rank participates, and
// finishes with ClearUp() plus a sanity check.
Maybe<void> ParallelDesc::MaybeInit(const ParallelConf& user_conf) {
  parallel_conf_ = user_conf;
  device_type_ = DeviceType::kInvalidDevice;
  DeviceType resolved_type = JUST(DeviceType4DeviceTag(parallel_conf_.device_tag()));
  CHECK_OR_RETURN(device_type_ == DeviceType::kInvalidDevice || device_type_ == resolved_type);
  device_type_ = resolved_type;
  machine_id2sorted_dev_phy_ids_ =
      std::make_shared<HashMap<int64_t, std::shared_ptr<std::vector<int64_t>>>>();
  for (const std::string& device_name : parallel_conf_.device_name()) {
    // A leading '@' marks an explicit rank: strip it and parse as a single
    // process; otherwise expand across every process of the node.
    const bool has_rank_prefix = (device_name[0] == '@');
    JUST(SetMachineIdAndDeviceIdsByParsingDeviceName(
        has_rank_prefix ? device_name.substr(1) : device_name,
        has_rank_prefix ? 1 : GlobalProcessCtx::NumOfProcessPerNode()));
  }
  containing_current_rank_ =
      machine_id2sorted_dev_phy_ids_->find(GlobalProcessCtx::Rank())
      != machine_id2sorted_dev_phy_ids_->end();
  ClearUp();
  JUST(SanityCheck());
  return Maybe<void>::Ok();
}
Split:表示把数据按照指定的维度进行切分,被切分出的数据块会被分发到前面Placement指定的各个物理设备中去
Broadcast:表示把整份数据广播到前面Placement指定的各个物理设备中去
Partial:表示前面Placement指定的各个物理设备中所存的数据不是最终的运算结果,需要对各个物理设备上的数据进行Elementwise的add/min/max等操作,才能得到最终的结果
import oneflow as flow
s=flow.sbp.split(1)
b=flow.sbp.broadcast
p=flow.sbp.partial_sum
<class 'oneflow._oneflow_internal.sbp.sbp'>
<class 'oneflow._oneflow_internal.sbp.sbp'>
<class 'oneflow._oneflow_internal.sbp.sbp'>
oneflow.sbp.split(axis=1)
oneflow.sbp.broadcast
oneflow.sbp.partial_sum
from . import sbp
import oneflow
from oneflow.framework.distribute import split_sbp as split
import oneflow._oneflow_internal
# Re-export the C++-implemented SBP class and the two parameterless SBP values
# at module level (this rebinds the `sbp` name imported above to the class).
sbp = oneflow._oneflow_internal.sbp.sbp
# broadcast/partial_sum are obtained by CALLING the factory once at import time;
# the C++ side returns a cached symbol (see GetBroadcastSbpParallel below).
broadcast = oneflow._oneflow_internal.sbp.broadcast()
partial_sum = oneflow._oneflow_internal.sbp.partial_sum()
# The definition of split_sbp is as follows
def split_sbp(axis: int) -> oneflow._oneflow_internal.sbp.sbp:
    """Return the split-SBP symbol for the given tensor axis.

    Args:
        axis: dimension along which a global tensor is split. Must be a plain
            ``int``; ``type(axis) is int`` is used deliberately so that
            ``bool`` (an ``int`` subclass) is rejected.

    Raises:
        AssertionError: if ``axis`` is not exactly of type ``int``.
    """
    # Added a diagnostic message so a bad call site is identifiable from the
    # traceback instead of a bare AssertionError.
    assert type(axis) is int, f"split expects an int axis, got {type(axis).__name__}"
    return oneflow._oneflow_internal.sbp.split(axis)
// Pybind11 bindings for the `oneflow._oneflow_internal.sbp` submodule
// (excerpt; some .def(...) lines are elided with "...").
ONEFLOW_API_PYBIND11_MODULE("sbp", m) {
  m.attr("max_split_axis") = kMaxSplitAxis;  // exported upper bound on split axes
  // The Python-visible `sbp` class wraps an interned Symbol<SbpParallel>.
  py::class_<Symbol<SbpParallel>, std::shared_ptr<Symbol<SbpParallel>>>(m, "sbp",
                                                                        py::dynamic_attr())
      .def("__str__", &api::SbpToString)
      ...
      ...
  // Module-level factories: flow.sbp.split(axis) / broadcast() / partial_sum().
  m.def("split", GetSplitSbpParallel, py::arg("axis"));
  m.def("broadcast", &GetBroadcastSbpParallel);
  m.def("partial_sum", &GetPartialSumSbpParallel);
}
oneflow/core/job/sbp_parallel.proto
build/oneflow/core/job/sbp_parallel.pb.h
build/oneflow/core/job/sbp_parallel.pb.cc
build/of_cfg_proto_python/oneflow/core/job/sbp_parallel_pb2.py
build/oneflow/core/job/sbp_parallel.cfg.h
build/oneflow/core/job/sbp_parallel.cfg.cpp
build/oneflow/core/job/sbp_parallel.cfg.pybind.cpp
// S(axis): the tensor is split along `axis` across the placement's devices.
message SplitParallel { required int64 axis = 1; }
// B: every device holds a full copy of the tensor.
message BroadcastParallel { }
// P: each device holds a partial value; an elementwise reduction yields the result.
message PartialSumParallel { }
// Exactly one of the three SBP variants.
message SbpParallel {
  oneof parallel_type {
    SplitParallel split_parallel = 1;
    BroadcastParallel broadcast_parallel = 2;
    PartialSumParallel partial_sum_parallel = 3;
  }
}
// blob name (in op) -> its SBP, for a 1-D placement.
message SbpSignature { map<string, SbpParallel> bn_in_op2sbp_parallel = 1; }
// One SbpParallel per hierarchy dimension (n-dimensional SBP).
message NdSbp { repeated SbpParallel sbp_parallel = 1; }
// blob name (in op) -> its n-dimensional SBP.
message NdSbpSignature { map<string, NdSbp> bn_in_op2nd_sbp = 1; }
message SbpSignatureList { repeated SbpSignature sbp_signature = 1; }
// Returns the cached split-SBP symbol for `axis`. All symbols for axes in
// [0, kMaxSplitAxis) are built once (thread-safe static-local init) and reused.
// Returns a Maybe error for any out-of-range axis.
Maybe<Symbol<SbpParallel>> GetSplitSbpParallel(int axis) {
  // Reject negative axes with a Maybe error; previously a negative value slipped
  // past the upper-bound check and made vector::at throw std::out_of_range.
  CHECK_GE_OR_RETURN(axis, 0);
  CHECK_LT_OR_RETURN(axis, kMaxSplitAxis);
  static std::vector<Symbol<SbpParallel>> split_sbp_sym_list =
      *JUST(MakeSplitSbpParallelList(kMaxSplitAxis));
  return split_sbp_sym_list.at(axis);
}
// Returns the process-wide broadcast SBP symbol, built exactly once via
// thread-safe C++11 static-local initialization and shared by all callers.
Maybe<Symbol<SbpParallel>> GetBroadcastSbpParallel() {
  static Symbol<SbpParallel> broadcast_sbp = JUST(MakeBroadcastSbpParallel());
  return broadcast_sbp;
}
// Returns the process-wide partial-sum SBP symbol, built exactly once via
// thread-safe C++11 static-local initialization and shared by all callers.
Maybe<Symbol<SbpParallel>> GetPartialSumSbpParallel() {
  static Symbol<SbpParallel> partial_sum_sbp = JUST(MakePartialSumSbpParallel());
  return partial_sum_sbp;
}
// Builds and interns a split-SBP symbol for `axis` by selecting the
// split_parallel branch of the SbpParallel oneof.
// Returns a Maybe error for any out-of-range axis.
Maybe<Symbol<SbpParallel>> MakeSplitSbpParallel(int axis) {
  // Validate both bounds: a negative axis is as invalid as an over-large one
  // (the original only checked the upper bound).
  CHECK_GE_OR_RETURN(axis, 0);
  CHECK_LT_OR_RETURN(axis, kMaxSplitAxis);
  SbpParallel split_sbp_parallel;
  split_sbp_parallel.mutable_split_parallel()->set_axis(axis);
  return SymbolOf(split_sbp_parallel);
}
// Builds and interns an SBP symbol whose oneof holds BroadcastParallel.
Maybe<Symbol<SbpParallel>> MakeBroadcastSbpParallel() {
  SbpParallel sbp_parallel;
  // Touching the mutable sub-message selects the broadcast branch of the oneof.
  sbp_parallel.mutable_broadcast_parallel();
  return SymbolOf(sbp_parallel);
}
// Builds and interns an SBP symbol whose oneof holds PartialSumParallel.
Maybe<Symbol<SbpParallel>> MakePartialSumSbpParallel() {
  SbpParallel sbp_parallel;
  // Touching the mutable sub-message selects the partial-sum branch of the oneof.
  sbp_parallel.mutable_partial_sum_parallel();
  return SymbolOf(sbp_parallel);
}
std::unordered_map<HashEqTraitPtr<const T>, std::shared_ptr<const T>>;
// Excerpt of TensorStorage's data members (methods elided with "...").
class TensorStorage {
  ...
  size_t blob_bytes_;  // size of the owned buffer, in bytes
  // Owned data buffer; the type-erased deleter knows how to free it.
  std::unique_ptr<char, std::function<void(char*)>> blob_dptr_;
  std::unique_ptr<MemoryAllocator> non_pod_allocator_;
  // Streams that produced / last used this storage, when known.
  // NOTE(review): presumably used for cross-stream ordering — confirm in full source.
  Optional<Symbol<Stream>> producer_stream_;
  Optional<Symbol<Stream>> last_used_stream_;
  std::vector<std::function<void()>> storage_delete_hooks_;  // callbacks run at deletion
};